# Author: Ashley S. Lee
# Brown University
# Data Science Practice
# Computing & Information Services
# 20181128

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as colors

def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : int
        Number of colors to sample from the base colormap.
    base_cmap : str, matplotlib.colors.Colormap or None, optional
        Colormap to sample from. ``None`` selects matplotlib's default
        colormap.

    Returns
    -------
    color_list : numpy.ndarray
        The N colors obtained by evaluating the base colormap at N evenly
        spaced positions in [0, 1].
    cmap : matplotlib.colors.LinearSegmentedColormap
        Colormap built from ``color_list`` with N levels, named
        ``<base colormap name><N>``.
    """
    # plt.get_cmap works for a string, None, or a colormap instance.
    # (matplotlib.cm.get_cmap, used previously, was deprecated in
    # Matplotlib 3.7 and removed in 3.9.)
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return color_list, colors.LinearSegmentedColormap.from_list(cmap_name, color_list, N)


def fraction_derived_corpus(dataframe, all_years=None):
    """Compute, per year, the fraction of debates that are in a derived corpus.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        The derived corpus. Rows are grouped by column ``0``, which is
        labelled "year" in the output.
    all_years : pandas.DataFrame, optional
        Per-year totals for the full corpus, with columns ``0`` and
        ``"count"``. When omitted, the module-level ``years_all`` table is
        used, so existing one-argument calls keep working.

    Returns
    -------
    years : pandas.DataFrame
        Row counts of ``dataframe`` per value of column ``0`` (columns ``0``
        and ``"count"``).
    years_percs : pandas.DataFrame
        One row per year with columns ``year``, ``num_debates`` (total),
        ``num_derived`` (count in ``dataframe``) and ``fraction``
        (``num_derived / num_debates``).
    """
    if all_years is None:
        # Backward-compatible default: the totals table built at module level.
        all_years = years_all

    years = dataframe.groupby(0).size().reset_index(name="count")

    # The outer merge keeps years that have no derived debates; their missing
    # count is filled with 0. Both inputs have a "count" column, so pandas
    # suffixes them: count_x is the total, count_y the derived count.
    years_percs = all_years.merge(years, how='outer', on=0).fillna(value=0)
    years_percs['fraction'] = years_percs['count_y'] / years_percs['count_x']
    years_percs.columns = ["year", "num_debates", "num_derived", "fraction"]
    return years, years_percs


def threshold_debates(percent_similarity, distance_array, metric, corpus_metadata=None):
    """Select the debates closest to any seed from a pre-computed distance matrix.

    For every column of ``distance_array`` (one column per seed) the rows with
    the smallest distances are picked; the union of those rows over all
    columns forms the derived corpus. The matching metadata rows are written
    to ``./output/titles_derived1_<metric>.txt`` (tab-separated, no header, no
    index) and returned.

    Parameters
    ----------
    percent_similarity : float
        Fraction (between 0 and 1) of the rows to keep per seed, e.g. ``0.01``
        for the closest 1%.
    distance_array : numpy.ndarray
        2-D array of distances, one row per debate and one column per seed.
    metric : str
        Label interpolated into the output file name.
    corpus_metadata : pandas.DataFrame, optional
        Table whose rows are positionally aligned with the rows of
        ``distance_array``. When omitted, the module-level ``metadata`` table
        is used, so existing three-argument calls keep working.

    Returns
    -------
    pandas.DataFrame
        The rows of the metadata table belonging to the derived corpus.

    Raises
    ------
    ValueError
        If ``percent_similarity`` is outside the interval [0, 1].
    """
    if not 0 <= percent_similarity <= 1:
        raise ValueError(
            "percent_similarity must be between 0 and 1, got {!r}".format(percent_similarity)
        )
    if corpus_metadata is None:
        # Backward-compatible default: the table loaded at module level.
        corpus_metadata = metadata

    distances = np.asarray(distance_array)
    n_debates = len(distances)

    # number of rows to keep per seed
    threshold = int(round(n_debates * percent_similarity))

    if threshold <= 0:
        idx_set = np.array([], dtype=int)
    elif threshold >= n_debates:
        # argpartition rejects kth == n; keeping everything needs no partition
        idx_set = np.arange(n_debates)
    else:
        # Partition every seed column at once: along axis 0 the first
        # `threshold` entries of each column index its smallest distances.
        nearest = np.argpartition(distances, threshold, axis=0)[:threshold]
        # np.unique flattens, giving the sorted union across all seeds
        idx_set = np.unique(nearest)

    # argpartition yields row positions, so select positionally
    # (DataFrame.ix, used previously, was removed in pandas 1.0)
    dc = corpus_metadata.iloc[idx_set, :]

    # write to tsv; only `metric` is part of the file name, so runs with a
    # different percent_similarity for the same metric overwrite each other
    os.makedirs('./output', exist_ok=True)
    dc.to_csv('./output/titles_derived1_{}.txt'.format(metric), sep='\t', header=False, index=False)

    return dc


# create output folder in local filesystem; a no-op when it already exists,
# and a real exception (instead of an ignored shell error) when it cannot be made
os.makedirs('output', exist_ok=True)

# -------------------------- PLOT 1 -------------------------- #

# four colors sampled from viridis, one per seed histogram; cm is the color
# list, cm2 the matching colormap object (not used for this plot)
cm, cm2 = discrete_cmap(4, 'viridis')

# load the KLD1 and JSD distance arrays (jsd is used further down the script)
kld1 = np.load("./KLD1.npy")
jsd = np.load("./JSD.npy")

# start from a fresh figure so nothing already drawn leaks into this plot
plt.figure()

# plot each seed as a different colored histogram
plt.hist(kld1, bins=100, histtype='step', stacked=False, fill=True, color=cm, alpha=0.5, 
         label=['napier','devon','richmond','bessborough'])
plt.grid(True)
plt.xlim([0,12])
# the title is raised (y = 1.2) to leave room for the legend placed above the axes
plt.title('Kullback-Leibler 1 Histograms', y = 1.2)
plt.xlabel('Divergence')
plt.ylabel('Frequency')
plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc = 3, ncol = 2, mode = 'expand', 
           borderaxespad = 0.)
plt.savefig('./output/jca_KL1_seedhists2_viridis.jpg', 
            bbox_inches='tight', dpi=300)
plt.show()
# On backends where show() does not block (e.g. headless Agg) the figure stays
# current; close it so later plt.* calls do not draw on top of this histogram.
plt.close()

# -------------------------- PLOT 2 -------------------------- #

# Read the per-debate metadata table (tab-separated, no header row, so the
# columns are numbered 0, 1, ...). Its rows are looked up by the indices
# selected from the distance matrices.
with open("./long_bills_stemmed_metadata.tsv", 'r') as metadata_file:
    metadata = pd.read_csv(metadata_file, sep='\t', header=None)

# Total number of debates for each value of column 0 (used as the year);
# this is the denominator of the per-year fractions plotted below.
years_all = (
    metadata
    .groupby(0)
    .size()
    .reset_index(name="count")
)

# Derived corpora: the 1% of debates closest to any seed, per divergence.
dc1kld1 = threshold_debates(0.01, kld1, 'kld1')  # Kullback Leibler
dc1jsd = threshold_debates(0.01, jsd, 'jsd')     # Jensen Shannon

# kld1: fraction of each year's debates that fall in the KLD1 derived corpus
years1kld1, years_percs1kld1 = fraction_derived_corpus(dc1kld1)

# Open a new figure: on backends where plt.show() does not block (e.g. headless
# Agg) the previous plot is still current and the bars would be drawn on it.
plt.figure()
plt.bar(years_percs1kld1["year"], years_percs1kld1['fraction'], align = 'center', width = 1, color='b', alpha=1, label='1% most similar')
plt.title("Fraction Derived Corpus (KLD1)")
plt.ylabel("fr. debates")
plt.xlabel("years")
plt.legend(bbox_to_anchor=(0, 1), loc='upper left', ncol=1, fontsize = 'xx-small')
plt.grid(True)
plt.savefig("./output/jca_fr_derived_year_kld1.jpg", dpi=300)
plt.show()
# release the figure so it cannot leak into any later plot
plt.close()

# jsd: fraction of each year's debates that fall in the JSD derived corpus
years1jsd, years_percs1jsd = fraction_derived_corpus(dc1jsd)

# Open a new figure: on backends where plt.show() does not block (e.g. headless
# Agg) the previous plot is still current and the bars would be drawn on it.
plt.figure()
plt.bar(years_percs1jsd["year"], years_percs1jsd['fraction'], align = 'center', width = 1, color='b', alpha=1, label='1% most similar')
plt.title("Fraction Derived Corpus (JSD)")
plt.ylabel("fr. debates")
plt.xlabel("years")
plt.legend(bbox_to_anchor=(0, 1), loc='upper left', ncol=1, fontsize = 'xx-small')
plt.grid(True)
plt.savefig("./output/jca_fr_derived_year_jsd.jpg", dpi=300)
plt.show()
# release the figure so it cannot leak into any later plot
plt.close()